def modify_gpt2_model_attention(model, remove_indices):
    """Rewire selected blocks so attention and MLP run side by side.

    For every ``idx`` in ``remove_indices`` the ``forward`` attribute of
    ``model.transformer.h[idx]`` is overridden on that instance only.
    The replacement feeds the *same* block input to both branches
    (``attn(ln_1(x))`` and ``mlp(ln_2(x))``) and sums them with the
    residual, instead of feeding the MLP the post-attention state.

    Args:
        model: Object exposing an indexable ``model.transformer.h`` whose
            items provide ``ln_1``, ``attn``, ``ln_2`` and ``mlp``
            callables.
        remove_indices: Iterable of indices into ``model.transformer.h``
            selecting the blocks to patch. (Attention is kept in this
            variant; only the data flow changes.)

    Returns:
        The same ``model`` object, patched in place.
    """

    # Defined once: the function captures nothing from the loop, so every
    # patched block can share it and only the bound ``self`` differs.
    def new_forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # ``encoder_hidden_states`` / ``encoder_attention_mask`` are accepted
        # for signature compatibility only; this forward never uses them.
        residual = hidden_states
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        # Everything after the attention output is forwarded to the caller
        # instead of being replaced by ``None``; dropping the cache entry
        # would leave this layer without its past keys/values.
        # NOTE(review): assumes ``self.attn`` returns
        # ``(output, present[, attn_weights])`` -- confirm against the
        # installed transformers version.
        extras = attn_outputs[1:]

        # The MLP reads the block input, not the post-attention state.
        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))

        hidden_states = residual + feed_forward_hidden_states + attn_output

        if use_cache:
            # (hidden_states, present[, attn_weights])
            return (hidden_states,) + extras
        # Without caching the ``present`` slot is skipped:
        # (hidden_states[, attn_weights])
        return (hidden_states,) + extras[1:]

    for idx in remove_indices:
        original_layer = model.transformer.h[idx]
        # Bind the plain function to this instance so ``self`` is supplied
        # automatically when the block is called.
        original_layer.forward = new_forward.__get__(
            original_layer, original_layer.__class__
        )
    return model


def modify_gpt2_model_attention_remove(model, remove_indices):
    """Strip the attention branch from selected blocks, keeping only the MLP.

    Each block ``model.transformer.h[idx]`` named in ``remove_indices`` gets
    an instance-level ``forward`` that computes
    ``hidden_states + mlp(ln_2(hidden_states))`` and never touches
    ``ln_1`` or ``attn``.

    Args:
        model: Object exposing an indexable ``model.transformer.h`` whose
            items provide ``ln_2`` and ``mlp`` callables.
        remove_indices: Iterable of indices into ``model.transformer.h``.

    Returns:
        The same ``model`` object, patched in place.
    """

    def new_forward(
        self,
        hidden_states,
        layer_past=None,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # All attention-related arguments are accepted but ignored.
        mlp_out = self.mlp(self.ln_2(hidden_states))
        # No attention ran, so the cache / attention-weight slots the
        # caller asked for are filled with ``None`` placeholders.
        placeholder_count = (1 if use_cache else 0) + (1 if output_attentions else 0)
        return (hidden_states + mlp_out,) + (None,) * placeholder_count

    for layer_idx in remove_indices:
        block = model.transformer.h[layer_idx]
        # Attach as a bound method on this instance only.
        block.forward = new_forward.__get__(block, block.__class__)
    return model